import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time

torch.manual_seed(0)

#torch.backends.cudnn.benchmark = True

configs = []
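# Each config entry is a 9-tuple:
#   (a_shape, b_shape, c_shape, torch_fn, einsum_expr, arrays, transform_a, transform_b, transform_c)
# torch_fn computes the PyTorch reference (falls back to torch.einsum when None),
# arrays holds extra named arrays referenced by the expression (e.g. sh, sw),
# and the transforms reshape inputs/outputs before/after the Triton einsum.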

# Matrix multiplication
MNK = [
        (512, 512, 512), 
        (2048, 2048, 2048),
        #(8192, 8192, 8192),
       
        (64, 64, 64000),
        (64, 64, 128000),
        (256, 256, 64000),
        (256, 256, 128000),

        (1536, 16, 1536),
        (1536, 32, 1536),
        (1536, 64, 1536),
        # (1536, 128, 1536),
        # (4096, 16, 4096),
        # (4096, 32, 4096),
        # (4096, 64, 4096),
        # (4096, 128, 4096),
    
        # (127008, 768, 576) 
      ]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b)
    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a.t(), b)
#    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a, b.t())
#    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]

# Relative attention
NTHSE = [
          (16, 512, 1, 64, 64), 
        #  (16, 512, 1, 128, 128),
        #  (16, 512, 1, 256, 256),
        #  (16, 512, 1, 256, 512),
          (16, 512, 8, 64, 64), 
        #  (16, 512, 8, 128, 128),
        #  (16, 512, 8, 256, 256),
        #  (16, 512, 8, 256, 512),

        #  (64, 1024, 1, 64, 64), 
          (64, 1024, 1, 128, 128),
        #  (64, 1024, 1, 256, 256),
        #  (64, 1024, 1, 256, 512),
        #  (64, 1024, 8, 64, 64), 
          (64, 1024, 8, 128, 128),
        #  (64, 1024, 8, 256, 256),
        #  (64, 1024, 8, 256, 512),

        #  (128, 1024, 1, 64, 64), 
        #  (128, 1024, 1, 128, 128),
        #  (128, 1024, 1, 256, 256),
          (128, 1024, 1, 256, 512),
        #  (128, 1024, 8, 64, 64), 
        #  (128, 1024, 8, 128, 128),
        #  (128, 1024, 8, 256, 256),
          #(128, 1024, 8, 256, 512)
        ]
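# The three relative-attention contractions below (the forward pass and the gradients
# w.r.t. each operand) are currently disabled; uncomment them to benchmark the NTHSE configs.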
#for N, T, H, S, E in NTHSE:
#    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]

# 1D Dense convolution
NCHKR = [
        #(1, 1152, 12602, 512, 3)
        ]
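# The expression 'nc(h+r),crk->nkh' uses the extended einsum notation: the affine index
# (h+r) ranges over the input length, so contracting over r implements a valid
# (no padding) 1D convolution with output length H - R + 1.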
for N, C, H, K, R in NCHKR:
    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
    configs += [([N, C, H], 
                 [C, R, K], 
                 [N, K, H - R + 1], 
                 torch_fn, 
                 'nc(h+r),crk->nkh',
                 dict(), None, None, None)]

# 2D Dense convolution
NCHWKRS = [
          # (N, C, G, H, W, K, R, S)
          #(8, 64, 1, 128, 128, 768, 3, 3),
          #(128, 3, 1, 32, 32, 64, 3, 3),
          #(1, 1024, 32, 112, 112, 1024, 3, 3),
          #(8, 512, 1, 32, 32, 1024, 3, 3)
          ]
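# Grouped, strided 2D convolution expressed as an einsum. The loop below hard-codes
# stride = 2, which is mirrored by the literal *2 factors in the index expressions
# (h*2+r) and (w*2+s); changing the stride requires updating both.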
for N, C, G, H, W, K, R, S in NCHWKRS:
    stride = 2
    P = (H - R + 1) // stride
    Q = (W - S + 1) // stride
    # bind loop variables as default arguments so each config keeps its own shapes
    torch_fn = lambda a, b, stride=stride, G=G: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
    transform_a = lambda a, N=N, G=G, C=C, H=H, W=W: a.view(N, G, C // G, H, W)
    transform_b = lambda b, C=C, G=G, R=R, S=S, K=K: b.view(C // G, R, S, G, K // G)
    transform_c = lambda c, N=N, K=K, P=P, Q=Q: c.view(N, K, P, Q)
    configs += [([N, C, H, W], 
                  [C // G, R, S, K], 
                  [N, G, K // G, P, Q], 
                  torch_fn, 
                  'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
                  dict(), transform_a, transform_b, transform_c)]


# 3D Dense Convolution
NCDHWKTRS = [
           #(8, 32, 27, 100, 100, 64, 3, 3, 3),
           #(8, 64, 23, 48, 48, 256, 3, 3, 3),
           #(8, 256, 19, 22, 22, 640, 3, 3, 3),
           #(8, 640, 15, 36, 36, 384, 3, 3, 3)
          ]
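# Same pattern in three dimensions: (d+t), (h+r) and (w+s) index the input volume,
# giving a valid 3D convolution with output size (D-T+1, H-R+1, W-S+1).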
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
    torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
    configs += [([N, C, D, H, W], 
                 [C, T, R, S, K], 
                 [N, K, D - T + 1, H - R + 1, W - S + 1], 
                 torch_fn, 
                 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
                 dict(), None, None, None)]


# Shift convolution
shift_cuda = load(
    'shift_cuda', ['kernels/shift_cuda.cpp', 
                   'kernels/shift_cuda_kernel.cu'],
    extra_cflags=['-O3'])
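# Thin autograd wrapper around the JIT-compiled extension so the shift op can be
# used like any other differentiable PyTorch function.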
class shift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        ctx.save_for_backward(shift)
        return shift_cuda.forward(x, shift)

    @staticmethod
    def backward(ctx, grad_output):
        shift, = ctx.saved_tensors
        grad_output = shift_cuda.backward(grad_output, shift)

        return grad_output, None

NCHWKRS = [
          #(8, 64, 128, 128, 128, 3, 3),
          #(8, 128, 64, 64, 256, 3, 3),
          #(8, 256, 32, 32, 512, 3, 3),
          #(8, 512, 32, 32, 1024, 3, 3)
          ]
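# Shift convolution: every channel c is spatially shifted by (sh[c], sw[c]) before a
# 1x1 convolution. The per-channel offsets are passed to the einsum through the
# `arrays` dict and appear in the expression as sh[c] and sw[c].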
for N, C, H, W, K, R, S in NCHWKRS:
    shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
    shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
    def shift_conv(a, b, **kwargs):
        shift_h, shift_w = kwargs['sh'], kwargs['sw']
        shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
        shift_torch = torch.from_numpy(shift_torch).cuda()
        a = shift.apply(a, shift_torch)
        b = b.permute(1, 0)
        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
        return torch.nn.functional.conv2d(a, b)
    configs += [([N, C, H, W], 
                  [C, K], 
                  [N, K, H, W], 
                  shift_conv, 
                  'nc(h + sh[c])(w + sw[c]),ck->nkhw',
                  {'sh': shift_h, 'sw': shift_w},
                  None, None, None)]

# Benchmark
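# For every config: build random inputs, compute a PyTorch reference, run the Triton
# einsum, and report throughput next to the relative error. When cmp_eqbmm is set, the
# kernel is also timed against an equivalent batched matmul of the same (B, M, N, K).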
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
    transform_a, transform_b, transform_c in configs:
    dtype = torch.cuda.FloatTensor
    # initialize input tensors
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # reference output
    if torch_fn:
        rc = torch_fn(a, b, **arrays)
    else:
        rc = torch.einsum(expr, a, b)
    # triton output
    ta = a if transform_a is None else transform_a(a)
    tb = b if transform_b is None else transform_b(b)
    tc = torch.empty(c_shape, device=a.device)
    triton.ops.einsum(expr, ta, tb, tc, arrays=arrays, bench=True)
    ctx = triton.ops._einsum.registry[tc]
    tc = tc if transform_c is None else transform_c(tc)
    # performance relative to equivalent matrix multiplication
    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
    cmp_eqbmm = True
    if cmp_eqbmm:
        a = torch.rand(B, M, K).type(dtype).cuda()
        b = torch.rand(B, K, N).type(dtype).cuda()
        c = torch.empty((B, M, N), device=a.device)
        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench=True)
        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
        cmp_str = f'({ratio:4.2f})'
    else:
        cmp_str = ''
    # test and benchmark
    bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3  # throughput derived from the equivalent GEMM FLOP count (2 * B * M * N * K)
    diff = (tc - rc).abs().max() / rc.abs().max()
    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20};          {bench:4.2f} {cmp_str};          {diff:4.2f}')
